{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"></ul></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TF Estimator in R <a class=\"tocSkip\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# example from https://tensorflow.rstudio.com/tfestimators/articles/examples/mnist.html\n",
    "\n",
    "library(tensorflow)\n",
    "library(tfestimators)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define hyperparameters -----------\n",
    "\n",
    "batch_size <- 128\n",
    "n_classes <- 10\n",
    "n_steps <- 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data Preparation -----------\n",
    "\n",
    "# initialize data directory\n",
    "data_dir <- \"~/datasets/mnist\"\n",
    "dir.create(data_dir, recursive = TRUE, showWarnings = FALSE)\n",
    "\n",
    "# download the MNIST data sets, and read them into R\n",
    "sources <- list(\n",
    "  \n",
    "  train = list(\n",
    "    x = \"https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz\",\n",
    "    y = \"https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz\"\n",
    "  ),\n",
    "  \n",
    "  test = list(\n",
    "    x = \"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz\",\n",
    "    y = \"https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz\"\n",
    "  )\n",
    "  \n",
    ")\n",
    "\n",
    "# read an MNIST file (encoded in IDX format)\n",
    "read_idx <- function(file) {\n",
    "  # Read a gzip-compressed IDX file holding unsigned-byte data.\n",
    "  #\n",
    "  # Args:\n",
    "  #   file: path to the gzip-compressed IDX file.\n",
    "  #\n",
    "  # Returns:\n",
    "  #   An integer vector when the file declares a single dimension;\n",
    "  #   otherwise an integer matrix with dims[1] rows, the remaining\n",
    "  #   dimensions flattened row-wise into the columns.\n",
    "  #\n",
    "  # Stops with an error when the header is malformed or the payload is\n",
    "  # shorter than the header promises, so that a truncated file is not\n",
    "  # silently recycled into a full-sized matrix.\n",
    "  \n",
    "  # create binary connection to file\n",
    "  conn <- gzfile(file, open = \"rb\")\n",
    "  on.exit(close(conn), add = TRUE)\n",
    "  \n",
    "  # read the magic number as sequence of 4 bytes:\n",
    "  # two zero bytes, the data type code, and the number of dimensions\n",
    "  magic <- readBin(conn, what = \"raw\", n = 4, endian = \"big\")\n",
    "  if (length(magic) != 4 || any(magic[1:2] != as.raw(0)))\n",
    "    stop(\"not an IDX file: \", file)\n",
    "  \n",
    "  # the payload is read byte-by-byte below, which is only valid for the\n",
    "  # unsigned-byte type code (0x08)\n",
    "  if (magic[[3]] != as.raw(0x08))\n",
    "    stop(\"unsupported IDX data type (expected unsigned byte): \", file)\n",
    "  ndims <- as.integer(magic[[4]])\n",
    "  \n",
    "  # read the dimensions (32-bit integers)\n",
    "  dims <- readBin(conn, what = \"integer\", n = ndims, endian = \"big\")\n",
    "  if (ndims < 1 || length(dims) != ndims)\n",
    "    stop(\"truncated IDX header: \", file)\n",
    "  \n",
    "  # read the rest in as a raw vector\n",
    "  expected <- prod(dims)\n",
    "  data <- readBin(conn, what = \"raw\", n = expected, endian = \"big\")\n",
    "  if (length(data) != expected)\n",
    "    stop(\"truncated IDX data: expected \", expected, \" bytes, got \",\n",
    "         length(data), \": \", file)\n",
    "  \n",
    "  # convert to an integer vector\n",
    "  converted <- as.integer(data)\n",
    "  \n",
    "  # return plain vector for 1-dim array\n",
    "  if (length(dims) == 1)\n",
    "    return(converted)\n",
    "  \n",
    "  # wrap multi-dimensional data into a matrix, one record per row\n",
    "  matrix(converted, nrow = dims[1], ncol = prod(dims[-1]), byrow = TRUE)\n",
    "}\n",
    "\n",
    "mnist <- rapply(sources, classes = \"character\", how = \"list\", function(url) {\n",
    "  \n",
    "  # Called once per URL in `sources`; the parsed results come back in a\n",
    "  # list with the same nesting as `sources`.\n",
    "  target <- file.path(data_dir, basename(url))\n",
    "  \n",
    "  # download only when the file is not cached yet; write to a temporary\n",
    "  # name first so an interrupted transfer is never mistaken for a\n",
    "  # complete file on the next run\n",
    "  if (!file.exists(target)) {\n",
    "    partial <- paste0(target, \".part\")\n",
    "    on.exit(unlink(partial), add = TRUE)\n",
    "    \n",
    "    # the files are gzip archives, so force a binary transfer\n",
    "    status <- download.file(url, partial, mode = \"wb\")\n",
    "    if (status != 0 || !file.rename(partial, target))\n",
    "      stop(\"failed to download \", url)\n",
    "  }\n",
    "  \n",
    "  # read the IDX file\n",
    "  read_idx(target)\n",
    "  \n",
    "})\n",
    "\n",
    "# convert training and test data intensities to 0-1 range\n",
    "mnist$train$x <- mnist$train$x / 255\n",
    "mnist$test$x <- mnist$test$x / 255\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Model ---------------\n",
    "\n",
    "# construct a linear classifier\n",
    "classifier <- linear_classifier(\n",
    "  feature_columns = feature_columns(\n",
    "    column_numeric(\"x\", shape = shape(784L))\n",
    "  ),\n",
    "  n_classes = n_classes  # 10 digits\n",
    ")\n",
    "\n",
    "# construct an input function generator\n",
    "mnist_input_fn <- function(data, ...) {\n",
    "  # Build an input function over `data`, using its `x` element as the\n",
    "  # features and its `y` element as the response. Batches are sized by\n",
    "  # the notebook-level `batch_size`; any extra arguments are forwarded\n",
    "  # unchanged to input_fn().\n",
    "  input_fn(data, features = \"x\", response = \"y\", batch_size = batch_size, ...)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training the model --------\n",
    "\n",
    "train(classifier, input_fn = mnist_input_fn(mnist$train), steps = n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Evaluation completed after 79 steps but 200 steps was specified\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<thead><tr><th scope=col>average_loss</th><th scope=col>loss</th><th scope=col>global_step</th><th scope=col>accuracy</th></tr></thead>\n",
       "<tbody>\n",
       "\t<tr><td>0.35656 </td><td>45.13418</td><td>100     </td><td>0.9057  </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{r|llll}\n",
       " average\\_loss & loss & global\\_step & accuracy\\\\\n",
       "\\hline\n",
       "\t 0.35656  & 45.13418 & 100      & 0.9057  \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/markdown": [
       "\n",
       "average_loss | loss | global_step | accuracy | \n",
       "|---|\n",
       "| 0.35656  | 45.13418 | 100      | 0.9057   | \n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "  average_loss loss     global_step accuracy\n",
       "1 0.35656      45.13418 100         0.9057  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Evaluate the model --------\n",
    "\n",
    "evaluate(classifier, input_fn = mnist_input_fn(mnist$test), steps = 200)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "R 3.4.2",
   "language": "R",
   "name": "ir342"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.4.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
